// SPDX-License-Identifier: Apache-2.0

package firrtl
package transforms

import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
import scala.collection.mutable
import firrtl.passes.{InlineInstances, PassException}
import firrtl.stage.Forms

/** Tags an annotation to be consumed by this transform */
case class FlattenAnnotation(target: Named) extends SingleTargetAnnotation[Named] {
  def duplicate(n: Named) = FlattenAnnotation(n)
}

/**
  * Takes flatten annotations for module instances and modules and inline the entire hierarchy of
  * modules down from the annotations. This transformation instantiates and is based on the
  * InlineInstances transformation.
  *
  * @note Flattening a module means inlining all its fully-defined child instances
  * @note Instances of extmodules are not (and cannot be) inlined
  */
class Flatten extends Transform with DependencyAPIMigration {

  override def prerequisites = Forms.LowForm
  override def optionalPrerequisites = Seq.empty
  override def optionalPrerequisiteOf = Forms.LowEmitters
  override def invalidates(a: Transform) = false

  val inlineTransform = new InlineInstances

  private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
    anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) {
      case ((modNames, instNames), ann) =>
        ann match {
          case FlattenAnnotation(CircuitName(c)) =>
            (
              circuit.modules.collect {
                case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
              }.toSet,
              instNames
            )
          case FlattenAnnotation(ModuleName(mod, cir))    => (modNames + ModuleName(mod, cir), instNames)
          case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
          case _                                          => throw new PassException("Annotation must be a FlattenAnnotation")
        }
    }

  /**
    *  Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and
    *  by rewriting the original circuit to refer to the new modules that will be inlined later.
    *  @return modified circuit and ModuleNames to inline
    */
  def duplicateSubCircuitsFromAnno(
    c:     Circuit,
    mods:  Set[ModuleName],
    insts: Set[ComponentName]
  ): (Circuit, Set[ModuleName]) = {
    val modMap = c.modules.map(m => m.name -> m).toMap
    val seedMods = mutable.Map.empty[String, String]
    val newModDefs = mutable.Set.empty[DefModule]
    val nsp = Namespace(c)

    /**
      *  We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later.
      *  It populates seedMods where we capture the mapping between the original module name of the instances came from annotation
      *  to a new module name that we will create as a replica of the original one.
      *  Note: We replace old modules with it replicas so that other instances of the same module can be left unchanged.
      */
    def rewriteMod(parent: DefModule)(x: Statement): Statement = x match {
      case _: Block => x.map(rewriteMod(parent))
      case WDefInstance(info, instName, moduleName, instTpe) =>
        if (
          insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main))))
          || mods.contains(ModuleName(parent.name, CircuitName(c.main)))
        ) {
          val newModName =
            if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN")
          seedMods += moduleName -> newModName
          WDefInstance(info, instName, newModName, instTpe)
        } else x
      case _ => x
    }

    val modifMods = c.modules.map { m => m.map(rewriteMod(m)) }

    /**
      *  Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations).
      *  Populates newModDefs, which are replicated modules used in the subcircuit that we create
      *  by recursively traversing modules captured inside seedMods and replicating them
      */
    def recDupMods(mods: Map[String, String]): Unit = {
      val replMods = mutable.Map.empty[String, String]

      def dupMod(x: Statement): Statement = x match {
        case _: Block => x.map(dupMod)
        case WDefInstance(info, instName, moduleName, instTpe) =>
          modMap(moduleName) match {
            case m: Module =>
              val newModName =
                if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN")
              replMods += moduleName -> newModName
              WDefInstance(info, instName, newModName, instTpe)
            case _ => x // Ignore extmodules
          }
        case _ => x
      }

      def dupName(name: String): String = mods(name)
      val newMods = mods.map { case (origName, newName) => modMap(origName).map(dupMod).map(dupName) }

      newModDefs ++= newMods

      if (replMods.size > 0) recDupMods(replMods.toMap)

    }
    recDupMods(seedMods.toMap)

    //convert newly created modules to ModuleName for inlining next (outside this function)
    val modsToInline = newModDefs.map { m => ModuleName(m.name, CircuitName(c.main)) }
    (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet)
  }

  override def execute(state: CircuitState): CircuitState = {
    val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a }
    annos match {
      case Nil => state
      case myAnnotations =>
        val (modNames, instNames) = collectAnns(state.circuit, myAnnotations)
        // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline
        val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames)
        inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations)
    }
  }
}
